import numpy as np
import faiss


class FaissKNeighbors:
    """Exact k-nearest-neighbour lookup backed by a flat (brute-force) L2 faiss index.

    The training vectors and their labels are kept alongside the index so that
    ``predict`` can return the neighbouring rows and labels themselves rather
    than just their positions.
    """

    def __init__(self, k=5):
        """Create an unfitted model that returns up to ``k`` neighbours per query."""
        self.index = None  # faiss.IndexFlatL2, built by fit()
        self.y = None  # labels, aligned row-for-row with self.X
        self.X = None  # training vectors (original dtype), returned by predict()
        self.k = k

    def load(self, path_x, path_y):
        """Load training vectors and labels from ``.npy`` files and fit on them.

        Args:
            path_x: path to a saved array of shape (n_samples, n_features).
            path_y: path to a saved array holding one label per sample.
        """
        self.fit(np.load(path_x), np.load(path_y))

    def fit(self, X, y):
        """Build the L2 index over ``X`` and remember ``X`` and ``y``.

        Args:
            X: array-like of shape (n_samples, n_features).
            y: array-like of n_samples labels aligned with the rows of ``X``.

        Raises:
            ValueError: if ``X`` is not 2-D or ``X`` and ``y`` differ in length.
        """
        X = np.asarray(X)
        y = np.asarray(y)  # lists would not support the fancy indexing in predict()
        if X.ndim != 2:
            raise ValueError(
                "X must be 2-D (n_samples, n_features), got shape {}".format(X.shape)
            )
        if len(X) != len(y):
            raise ValueError(
                "X and y must have the same length, got {} and {}".format(len(X), len(y))
            )
        index = faiss.IndexFlatL2(X.shape[1])
        # faiss requires C-contiguous float32 input; astype() alone keeps the
        # source memory order (e.g. Fortran), so force contiguity explicitly.
        index.add(np.ascontiguousarray(X, dtype=np.float32))
        # Assign only after the index was built, so a failure keeps prior state.
        self.index = index
        self.X = X
        self.y = y

    def predict(self, X):
        """Return the nearest training rows and their labels for the first query.

        Args:
            X: a single query vector of n_features values, or a 2-D array of
                shape (n_queries, n_features). As in the original API, only
                the first query's neighbours are returned.

        Returns:
            A tuple ``(neighbours, labels)`` of ``self.X`` rows and matching
            ``self.y`` entries, in the order faiss reports them (nearest
            first). At most ``min(k, n_samples)`` neighbours are returned.

        Raises:
            RuntimeError: if called before ``fit``/``load`` or on an empty index.
        """
        if self.index is None or self.index.ntotal == 0:
            raise RuntimeError("FaissKNeighbors.predict called before fit/load")
        queries = np.ascontiguousarray(np.atleast_2d(X), dtype=np.float32)
        # faiss pads results with label -1 when k exceeds the number of indexed
        # vectors; -1 would silently select the last training row.
        k = min(self.k, self.index.ntotal)
        # Only the first query's result is returned, so only search that row.
        _, indices = self.index.search(queries[:1], k)
        nearest = indices[0]
        return self.X[nearest], self.y[nearest]